#############################################################
# This QGPO policy is a modified version of the implementation at https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################

from typing import List, Dict, Any
import torch
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from ding.torch_utils import to_device
from .base_policy import Policy


@POLICY_REGISTRY.register('qgpo')
class QGPOPolicy(Policy):
    """
    Overview:
        Policy class of QGPO algorithm (https://arxiv.org/abs/2304.12824).
        Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning
    Interfaces:
        ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='qgpo',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in QGPO.
        on_policy=False,
        multi_agent=False,
        model=dict(
            qgpo_critic=dict(
                # (float) The scale of the energy guidance when training qt.
                # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a))
                alpha=3,
                # (float) The scale of the energy guidance when training q0.
                # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a')
                # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a))
                q_alpha=1,
            ),
            device='cuda',
            # obs_dim
            # action_dim
        ),
        learn=dict(
            # learning rate for behavior model training
            learning_rate=1e-4,
            # learning rate of the Adam optimizers of the q0 and qt networks
            learning_rate_q=3e-4,
            # momentum of the soft update of the q0 target network:
            # target <- momentum * online + (1 - momentum) * target
            q_target_update_momentum=0.005,
            # discount factor passed to the q0 loss
            discount_factor=0.99,
            # batch size during the training of behavior model
            batch_size=4096,
            # batch size during the training of q value
            batch_size_q=256,
            # number of fake action support
            M=16,
            # number of diffusion time steps
            diffusion_steps=15,
            # training iterations when behavior model is fixed
            behavior_policy_stop_training_iter=600000,
            # training iterations when energy-guided policy begin training
            energy_guided_policy_begin_training_iter=600000,
            # training iterations when q value stop training; None means no limit
            q_value_stop_training_iter=1100000,
        ),
        eval=dict(
            # energy guidance scale for policy in evaluation
            # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a))
            guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
            # (int, optional) diffusion_steps: number of diffusion steps used when sampling actions
            # in evaluation. When not set here, ``learn.diffusion_steps`` is used instead.
        ),
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode initialization method. For QGPO, it mainly contains the optimizers of the behavior \
            (score) model, q0 and qt, and algorithm-specific arguments such as qt_update_momentum, discount, \
            behavior_policy_stop_training_iter, energy_guided_policy_begin_training_iter and \
            q_value_stop_training_iter.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
        """
        self.cuda = self._cfg.cuda
        learn_cfg = self._cfg.learn

        self.behavior_model_optimizer = torch.optim.Adam(
            self._model.score_model.parameters(), lr=learn_cfg.learning_rate
        )
        # ``.get`` fallbacks keep the previously hard-coded values for configs that lack these keys.
        q_learning_rate = learn_cfg.get('learning_rate_q', 3e-4)
        self.q_optimizer = torch.optim.Adam(self._model.q.q0.parameters(), lr=q_learning_rate)
        self.qt_optimizer = torch.optim.Adam(self._model.q.qt.parameters(), lr=q_learning_rate)

        # Momentum of the soft update applied to ``q0_target`` (despite the "qt" in the attribute name).
        self.qt_update_momentum = learn_cfg.get('q_target_update_momentum', 0.005)
        self.discount = learn_cfg.get('discount_factor', 0.99)

        # Schedule counters; they are decremented inside ``_forward_learn``.
        self.behavior_policy_stop_training_iter = learn_cfg.behavior_policy_stop_training_iter
        self.energy_guided_policy_begin_training_iter = learn_cfg.energy_guided_policy_begin_training_iter
        # May be None, meaning q0 never stops training.
        self.q_value_stop_training_iter = learn_cfg.q_value_stop_training_iter

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward function for learning mode.
            The training of QGPO algorithm is based on contrastive energy prediction, \
            which needs true action and fake action. The true action is sampled from the dataset, and the fake action \
            is sampled from the action support generated by the behavior policy.
            The training process is divided into three stages, scheduled by iteration counters:
            1. Train the behavior model, which is modeled as a diffusion model by parameterizing the score function.
            2. Train the Q function by fake action support generated by the behavior model.
            3. Train the energy-guided policy by the Q function.
        Arguments:
            - data (:obj:`dict`): Dict type data with keys ``s``, ``a``, ``r``, ``s_``, ``d``, ``fake_a`` and \
                ``fake_a_``.
        Returns:
            - result (:obj:`dict`): Dict type data of algorithm results, containing ``total_loss``, \
                ``behavior_model_training_loss``, ``q0_loss`` and ``qt_loss`` (0 for a stage that did not run).
        """

        if self.cuda:
            data = to_device(data, self._device)

        s = data['s']
        a = data['a']
        r = data['r']
        s_ = data['s_']
        d = data['d']
        fake_a = data['fake_a']
        fake_a_ = data['fake_a_']

        # Stage 1: behavior model, trained only while its iteration budget is positive.
        behavior_model_training_loss = self._train_behavior_model(a, s)

        # Both counters advance on every call, before they are checked.
        self.energy_guided_policy_begin_training_iter -= 1
        if self.q_value_stop_training_iter is not None:
            self.q_value_stop_training_iter -= 1

        q0_loss = 0
        qt_loss = 0
        if self.energy_guided_policy_begin_training_iter < 0:
            # Stage 2: q0 trains until its (optional) stop iteration is reached.
            if self.q_value_stop_training_iter is None or self.q_value_stop_training_iter > 0:
                q0_loss = self._train_q0(a, s, r, s_, d, fake_a_)
            # Stage 3: the energy model qt keeps training once the guided stage has begun.
            qt_loss = self._train_qt(s, fake_a)

        total_loss = behavior_model_training_loss + q0_loss + qt_loss

        return dict(
            total_loss=total_loss,
            behavior_model_training_loss=behavior_model_training_loss,
            q0_loss=q0_loss,
            qt_loss=qt_loss,
        )

    def _train_behavior_model(self, a, s) -> float:
        """
        Overview:
            Run one optimizer step on the behavior score model while its iteration budget is positive.
        Returns:
            - loss (:obj:`float`): The training loss, or 0 when the behavior model is already frozen.
        """
        if self.behavior_policy_stop_training_iter <= 0:
            return 0
        loss = self._model.score_model_loss_fn(a, s)

        self.behavior_model_optimizer.zero_grad()
        loss.backward()
        self.behavior_model_optimizer.step()

        self.behavior_policy_stop_training_iter -= 1
        return loss.item()

    def _train_q0(self, a, s, r, s_, d, fake_a_) -> float:
        """
        Overview:
            Run one optimizer step on q0, then soft-update its target network ``q0_target``.
        Returns:
            - loss (:obj:`float`): The q0 training loss.
        """
        loss = self._model.q_loss_fn(a, s, r, s_, d, fake_a_, discount=self.discount)

        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()

        # Soft update: target <- momentum * online + (1 - momentum) * target.
        for param, target_param in zip(self._model.q.q0.parameters(), self._model.q.q0_target.parameters()):
            target_param.data.copy_(
                self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data
            )

        return loss.item()

    def _train_qt(self, s, fake_a) -> float:
        """
        Overview:
            Run one optimizer step on the intermediate energy model qt using the fake action support.
        Returns:
            - loss (:obj:`float`): The qt training loss.
        """
        loss = self._model.qt_loss_fn(s, fake_a)

        self.qt_optimizer.zero_grad()
        loss.backward()
        self.qt_optimizer.step()

        return loss.item()

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode initialization method. Not supported for QGPO.
        """
        pass

    def _forward_collect(self) -> None:
        """
        Overview:
            Forward function for collect mode. Not supported for QGPO.
        """
        pass

    def _init_eval(self) -> None:
        """
        Overview:
            Eval mode initialization method. For QGPO, it mainly sets the number of diffusion steps used \
            for action sampling. ``eval.diffusion_steps`` is used when present; otherwise it falls back to \
            ``learn.diffusion_steps`` (the default config only defines the latter), and finally to 15.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
        """
        diffusion_steps = self._cfg.eval.get('diffusion_steps', None)
        if diffusion_steps is None:
            diffusion_steps = self._cfg.get('learn', {}).get('diffusion_steps', 15)
        self.diffusion_steps = diffusion_steps

    def _forward_eval(self, data: dict, guidance_scale: float) -> dict:
        """
        Overview:
            Forward function for eval mode. The eval process is based on the energy-guided policy, \
            which is modeled as a diffusion model by parameterizing the score function.
        Arguments:
            - data (:obj:`dict`): Dict type data, mapping an env id to its observation.
            - guidance_scale (:obj:`float`): The scale of the energy guidance.
        Returns:
            - output (:obj:`dict`): Dict mapping each env id to ``{"action": action}``.
        """

        data_id = list(data.keys())
        states = default_collate(list(data.values()))
        actions = self._model.select_actions(
            states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale
        )

        return {i: {"action": d} for i, d in zip(data_id, actions)}

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            Get the train sample from the replay buffer, currently not supported for QGPO.
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]]`): The data from replay buffer.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The data for training.
        """
        pass

    def _process_transition(self) -> None:
        """
        Overview:
            Process the transition data, currently not supported for QGPO.
        """
        pass

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state dict for saving, including the model and all three optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
        """
        return {
            'model': self._model.state_dict(),
            'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(),
            'q_optimizer': self.q_optimizer.state_dict(),
            'qt_optimizer': self.qt_optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state dict. The q0/qt optimizer states are restored only when present, so \
            checkpoints that contain just ``model`` and ``behavior_model_optimizer`` still load.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
        """
        self._model.load_state_dict(state_dict['model'])
        self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer'])
        if 'q_optimizer' in state_dict:
            self.q_optimizer.load_state_dict(state_dict['q_optimizer'])
        if 'qt_optimizer' in state_dict:
            self.qt_optimizer.load_state_dict(state_dict['qt_optimizer'])

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the variables names to be monitored.
        """
        return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss']
